{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "0",
   "metadata": {},
   "source": [
    "# Query Vision Language Model"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "1",
   "metadata": {},
   "source": [
    "## Querying Qwen-VL"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2",
   "metadata": {},
   "outputs": [],
   "source": [
    "import nest_asyncio\n",
    "\n",
    "nest_asyncio.apply()  # Run this first.\n",
    "\n",
    "model_path = \"Qwen/Qwen2.5-VL-3B-Instruct\"\n",
    "chat_template = \"qwen2-vl\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Lets create a prompt.\n",
    "\n",
    "from io import BytesIO\n",
    "import requests\n",
    "from PIL import Image\n",
    "\n",
    "from sglang.srt.parser.conversation import chat_templates\n",
    "\n",
    "image = Image.open(\n",
    "    BytesIO(\n",
    "        requests.get(\n",
    "            \"https://github.com/sgl-project/sglang/blob/main/test/lang/example_image.png?raw=true\"\n",
    "        ).content\n",
    "    )\n",
    ")\n",
    "\n",
    "conv = chat_templates[chat_template].copy()\n",
    "conv.append_message(conv.roles[0], f\"What's shown here: {conv.image_token}?\")\n",
    "conv.append_message(conv.roles[1], \"\")\n",
    "conv.image_data = [image]\n",
    "\n",
    "print(conv.get_prompt())\n",
    "image"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "4",
   "metadata": {},
   "source": [
    "### Query via the offline Engine API"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5",
   "metadata": {},
   "outputs": [],
   "source": [
    "from sglang import Engine\n",
    "\n",
    "llm = Engine(\n",
    "    model_path=model_path, chat_template=chat_template, mem_fraction_static=0.8\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6",
   "metadata": {},
   "outputs": [],
   "source": [
    "out = llm.generate(prompt=conv.get_prompt(), image_data=[image])\n",
    "print(out[\"text\"])"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "7",
   "metadata": {},
   "source": [
    "### Query via the offline Engine API, but send precomputed embeddings"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Compute the image embeddings using Huggingface.\n",
    "\n",
    "from transformers import AutoProcessor\n",
    "from transformers import Qwen2_5_VLForConditionalGeneration\n",
    "\n",
    "processor = AutoProcessor.from_pretrained(model_path, use_fast=True)\n",
    "vision = (\n",
    "    Qwen2_5_VLForConditionalGeneration.from_pretrained(model_path).eval().visual.cuda()\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9",
   "metadata": {},
   "outputs": [],
   "source": [
    "processed_prompt = processor(\n",
    "    images=[image], text=conv.get_prompt(), return_tensors=\"pt\"\n",
    ")\n",
    "input_ids = processed_prompt[\"input_ids\"][0].detach().cpu().tolist()\n",
    "precomputed_embeddings = vision(\n",
    "    processed_prompt[\"pixel_values\"].cuda(), processed_prompt[\"image_grid_thw\"].cuda()\n",
    ")\n",
    "\n",
    "mm_item = dict(\n",
    "    modality=\"IMAGE\",\n",
    "    image_grid_thw=processed_prompt[\"image_grid_thw\"],\n",
    "    precomputed_embeddings=precomputed_embeddings,\n",
    ")\n",
    "out = llm.generate(input_ids=input_ids, image_data=[mm_item])\n",
    "print(out[\"text\"])"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "10",
   "metadata": {},
   "source": [
    "## Querying Llama 4 (Vision)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "11",
   "metadata": {},
   "outputs": [],
   "source": [
    "import nest_asyncio\n",
    "\n",
    "nest_asyncio.apply()  # Run this first.\n",
    "\n",
    "model_path = \"meta-llama/Llama-4-Scout-17B-16E-Instruct\"\n",
    "chat_template = \"llama-4\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "12",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Lets create a prompt.\n",
    "\n",
    "from io import BytesIO\n",
    "import requests\n",
    "from PIL import Image\n",
    "\n",
    "from sglang.srt.parser.conversation import chat_templates\n",
    "\n",
    "image = Image.open(\n",
    "    BytesIO(\n",
    "        requests.get(\n",
    "            \"https://github.com/sgl-project/sglang/blob/main/test/lang/example_image.png?raw=true\"\n",
    "        ).content\n",
    "    )\n",
    ")\n",
    "\n",
    "conv = chat_templates[chat_template].copy()\n",
    "conv.append_message(conv.roles[0], f\"What's shown here: {conv.image_token}?\")\n",
    "conv.append_message(conv.roles[1], \"\")\n",
    "conv.image_data = [image]\n",
    "\n",
    "print(conv.get_prompt())\n",
    "print(f\"Image size: {image.size}\")\n",
    "\n",
    "image"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "13",
   "metadata": {},
   "source": [
    "### Query via the offline Engine API"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "14",
   "metadata": {},
   "outputs": [],
   "source": [
    "from sglang.test.test_utils import is_in_ci\n",
    "\n",
    "if not is_in_ci():\n",
    "    from sglang import Engine\n",
    "\n",
    "    llm = Engine(\n",
    "        model_path=model_path,\n",
    "        trust_remote_code=True,\n",
    "        enable_multimodal=True,\n",
    "        mem_fraction_static=0.8,\n",
    "        tp_size=4,\n",
    "        attention_backend=\"fa3\",\n",
    "        context_length=65536,\n",
    "    )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "15",
   "metadata": {},
   "outputs": [],
   "source": [
    "if not is_in_ci():\n",
    "    out = llm.generate(prompt=conv.get_prompt(), image_data=[image])\n",
    "    print(out[\"text\"])"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "16",
   "metadata": {},
   "source": [
    "### Query via the offline Engine API, but send precomputed embeddings"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "17",
   "metadata": {},
   "outputs": [],
   "source": [
    "if not is_in_ci():\n",
    "    # Compute the image embeddings using Huggingface.\n",
    "\n",
    "    from transformers import AutoProcessor\n",
    "    from transformers import Llama4ForConditionalGeneration\n",
    "\n",
    "    processor = AutoProcessor.from_pretrained(model_path, use_fast=True)\n",
    "    model = Llama4ForConditionalGeneration.from_pretrained(\n",
    "        model_path, torch_dtype=\"auto\"\n",
    "    ).eval()\n",
    "    vision = model.vision_model.cuda()\n",
    "    multi_modal_projector = model.multi_modal_projector.cuda()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "18",
   "metadata": {},
   "outputs": [],
   "source": [
    "if not is_in_ci():\n",
    "    processed_prompt = processor(\n",
    "        images=[image], text=conv.get_prompt(), return_tensors=\"pt\"\n",
    "    )\n",
    "    print(f'{processed_prompt[\"pixel_values\"].shape=}')\n",
    "    input_ids = processed_prompt[\"input_ids\"][0].detach().cpu().tolist()\n",
    "\n",
    "    image_outputs = vision(\n",
    "        processed_prompt[\"pixel_values\"].to(\"cuda\"), output_hidden_states=False\n",
    "    )\n",
    "    image_features = image_outputs.last_hidden_state\n",
    "    vision_flat = image_features.view(-1, image_features.size(-1))\n",
    "    precomputed_embeddings = multi_modal_projector(vision_flat)\n",
    "\n",
    "    mm_item = dict(modality=\"IMAGE\", precomputed_embeddings=precomputed_embeddings)\n",
    "    out = llm.generate(input_ids=input_ids, image_data=[mm_item])\n",
    "    print(out[\"text\"])"
   ]
  }
 ],
 "metadata": {
  "jupytext": {
   "cell_metadata_filter": "-all",
   "custom_cell_magics": "kql",
   "encoding": "# -*- coding: utf-8 -*-"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
